//
// Created by song on 16-11-18.
//

#ifndef C0COMPILER_SEMANTIC_H
#define C0COMPILER_SEMANTIC_H

#include "ASTNode.h"
#include "SymbolTable.h"
#include "Lexer.h"
#include "ASTVisitor.h"



/**
 * this class should be used when symbolTable is ready.
 */
class PrintAST: public ASTVisitor{
private:
    int level;
    map<TokenType, string> typeStr;
public:
    PrintAST():
            level(-1){
        typeStr[Add]=         "+";
        typeStr[Minus]=       "-";
        typeStr[Mul]=         "*";
        typeStr[Div]=         "/";
        typeStr[LessThan]=    "<";
        typeStr[LessEqual]=   "<=";
        typeStr[GreatThan]=   ">";
        typeStr[GreatEqual]=  ">=";
        typeStr[NotEqual]=    "!=";
        typeStr[Equal]=       "==";
    }
protected:
    bool defaultAction = true;
    bool visit(ASTNode& node, bool isPostVisit){
        if(!isPostVisit){
            level++;
            for(int i=0;i<level;i++){
                cout<<"|---";
            }
            cout<< node.toString();
            ASTVisitor::visit(node, isPostVisit);
            if(node.needCheck){
                cout<<'~';
            }
            cout<<endl;
        }else{
            level--;
        }
        return defaultAction;
    }

    virtual bool constDeclar(ASTNode&node) override {
        cout<<"[" << node.content << "]("<< node.id <<")";
        return ASTVisitor::constDeclar(node);
    }

    virtual bool varDeclar(ASTNode &node) override {
        cout<<"[" << node.content << "]("<< node.id <<")";
        return ASTVisitor::varDeclar(node);
    }

    virtual bool arrayDeclar(ASTNode &node) override {
        cout<<"[" << node.content << "]("<< node.id <<")";
        return ASTVisitor::arrayDeclar(node);
    }

    virtual bool funDeclar(ASTNode &node, bool b) override {
        cout<<'['<<node.content<<']'<<"("<< node.id <<")";
        return ASTVisitor::funDeclar(node, b);
    }

    virtual bool param(ASTNode &node) override {
        cout<<"[" << node.content << "]("<< node.id <<")";
        return ASTVisitor::param(node);
    }

    virtual bool functionInvoke(ASTNode &node, bool b) override {
        cout<<'['<<node.content<<']'<<"("<< node.id <<")";
        return ASTVisitor::functionInvoke(node, false);
    }

    virtual bool arraySubscript(ASTNode &node, bool b) override {
        cout<< "[" << node.content << "]("<< node.id <<")";
        return ASTVisitor::arraySubscript(node, false);
    }

    virtual bool identifier(ASTNode &node) override {
        cout<<"[" << node.content << "]("<< node.id <<")";
        return ASTVisitor::identifier(node);
    }

    virtual bool number(ASTNode &node) override {
        cout<< "[" << node.value << "]";
        return ASTVisitor::number(node);
    }

    virtual bool charConst(ASTNode &node) override {
        cout << "['" << node.ch << "']";
        return ASTVisitor::charConst(node);
    }

    virtual bool stringConst(ASTNode &node) override {
        cout << "[" << node.content << "]";
        return ASTVisitor::stringConst(node);
    }

    virtual bool relationExpr(ASTNode &node, bool postVisit) override {
        cout << "[" << typeStr[node.contentType] << "]";
        return ASTVisitor::relationExpr(node, postVisit);
    }

    virtual bool addSubExpr(ASTNode &node, bool postVisit) override {
        cout << "[" << typeStr[node.contentType] << "]";
        return ASTVisitor::addSubExpr(node, postVisit);
    }

    virtual bool multiDivExpr(ASTNode &node, bool postVisit) override {
        cout << "[" << typeStr[node.contentType] << "]";
        return ASTVisitor::multiDivExpr(node, postVisit);
    }

    virtual bool caseStmt(ASTNode &node, bool postVisit) override {
        if(node.contentType==Char){
            cout << "['" << node.ch << "']";
        }else{
            cout << "[" << node.value << "]";
        }
        return ASTVisitor::caseStmt(node, postVisit);
    }
};


class BuildSymbolTable: public ASTVisitor{
private:
    Lexer& lexer;
    SymbolTable& symTable;

public:
    BuildSymbolTable(Lexer& lexer, SymbolTable& table):
            lexer(lexer),
            symTable(table)
    {
        this->defaultAction = true;
    }

private:

    void semanticErr(Error::ErrorNo no, ASTNode& node){
        Error::nextErrorDetail << endl
                               << lexer.charIndex2pos(node.start) << endl
                               << lexer.pointPos(node.start, node.end);
        Error::semantic(no);
    }

    virtual bool constDeclar(ASTNode &node) override {
        int id = symTable.lookUpVar(node.content, true);
        if(id==-1){// not found, insert.
            if(node.contentType==Char){
                node.id = symTable.enterConst(node.content, node.contentType, node.ch);
            }else{
                node.id = symTable.enterConst(node.content, node.contentType, node.value);
            }
        }else{// duplicate name.
            Error::nextErrorDetail << node.content;
            semanticErr(Error::Duplicate_Var_Name, node);
        }
        return false;
    }

    virtual bool varDeclar(ASTNode &node) override {
        int id = symTable.lookUpVar(node.content, true);
        if(id==-1) {// not found, insert.
            node.id = symTable.enterVar(node.content, node.contentType);
        }else{
            Error::nextErrorDetail << node.content;
            semanticErr(Error::Duplicate_Var_Name, node);
        }
        return false;
    }

    virtual bool arrayDeclar(ASTNode &node) override {
        int id = symTable.lookUpVar(node.content, true);
        if(id==-1) {// not found, insert.
            node.id = symTable.enterArray(node.content, node.contentType, node.value);
        }else{
            Error::nextErrorDetail << node.content;
            semanticErr(Error::Duplicate_Var_Name, node);
        }
        return false;
    }

    virtual bool funDeclar(ASTNode &node, bool isPostVisit) override {
        if(!isPostVisit){
            int id = symTable.lookUpFun(node.content);
            if(id==-1) {// not found, insert.
                vector<TokenType> arguments;
                vector<ASTNode *>::const_iterator iter;
                for (iter = node.children.begin();
                     iter != node.children.end(); iter++) {
                    if ((*iter)->astType == Param) {
                        arguments.push_back((*iter)->contentType);
                    } else {
                        break;
                    }
                }
                node.id = symTable.enterFunction(node.content, node.contentType, arguments);
                symTable.goIntoFunction();
            }else{
                Error::nextErrorDetail << node.content;
                semanticErr(Error::Duplicate_Function_Name, node);
            }
        }else{
            symTable.getOutFunction();
        }
        return true;
    }

    virtual bool param(ASTNode &node) override {
        int id = symTable.lookUpVar(node.content, true);
        if(id==-1) { // not found, insert.
            node.id = symTable.enterVar(node.content, node.contentType);
        }else{
            Error::nextErrorDetail << node.content;
            semanticErr(Error::Duplicate_Var_Name, node);
        }
        return false;
    }

    virtual bool functionInvoke(ASTNode &node, bool b) override {
        int id = symTable.lookUpFun(node.content);
        if(id==-1){
            Error::nextErrorDetail << node.content;
            semanticErr(Error::Function_Declar_Not_Found, node);
        }
        node.id = id;
        return true;
    }

    virtual bool arraySubscript(ASTNode &node, bool b) override {
        int id = symTable.lookUpVar(node.content, false);
        if(id==-1){
            Error::nextErrorDetail << node.content;
            semanticErr(Error::Array_Declar_Not_Found, node);
        }
        if(!symTable.getVar(id)->isArray){
            Error::nextErrorDetail << node.content;
            semanticErr(Error::Array_Declar_Not_Found, node);
        }
        node.id = id;
        return true;
    }

    // find and link variables.
    virtual bool identifier(ASTNode &node) override {
        int id = symTable.lookUpVar(node.content, false);
        if(id==-1){
            Error::nextErrorDetail << node.content;
            semanticErr(Error::Identifier_Declar_Not_Found, node);
        }
        node.id = id;
        return false;
    }

};



/**
     * check list:
     * 0. evaluate expression to get its type.
     * 1. function call
     *     -- argument count & type match --- done
     *     -- void function in expression --- done
     * 2. function returns declared type.(check function body) --- done
     * 3. assignment type match & const & array name.  --- done
     */
class TypeChecker: public ASTVisitor{
private:
    Lexer& lexer;
    SymbolTable& symTable;
    SymbolTable::Type funReturnType;
    bool funHasReturnStmt;

public:
    TypeChecker(Lexer& lexer, SymbolTable table):
            lexer(lexer),
            symTable(table)
    {
        this->defaultAction = true;
        this->postVisitSet = {
            FunctionInvoke,
            AddSubExpr,
            MultiDivExpr,
            AssignStmt,
            FunctionDeclar,
            ReturnStmt
        };
    }

private:
    void semanticErr(Error::ErrorNo no, ASTNode& node){
        Error::nextErrorDetail << endl
                               << lexer.charIndex2pos(node.start) << endl
                               << lexer.pointPos(node.start, node.end);
        Error::semantic(no);
    }

    SymbolTable::Type typeEval(ASTNode& node){
        SymbolTable::VarRecord* v;
        SymbolTable::FunRecord* f;
        switch(node.astType){
            case ArraySubscript:
            case Identifier:
                v = symTable.getVar(node.id);
                return v->type;
            case FunctionInvoke:
                f = symTable.getFun(node.id);
                return f->returnType;
            case AddSubExpr:
            case MultiDivExpr:
            case Number:
                return SymbolTable::INT;
            case CharConst:
                return SymbolTable::CHAR;
            case ConditionExpr:
            case StringConst:
            case RelationExpr:
            case SourceCode:
            case ConstDeclar:
            case VariableDeclar:
            case ArrayDeclar:
            case FunctionDeclar:
            case Param:
            case IfStmt:
            case WhileStmt:
            case SwitchStmt:
            case CaseSegment:
            case DefaultSegment:
            case AssignStmt:
            case ReturnStmt:
                semanticErr(Error::Should_Not_Happen, node);

        }
    }

    bool needRuntimeCheck(ASTNode *node){
        int v;
        if(node->astType==Number){ // only number or case label can be checked static
            v = node->value;
        }else if(node->astType==CaseSegment) {
            if (node->contentType == Int) {
                v = node->value;
            } else {
                v = node->ch;
            }
        }else{
            return true;
        }
        if( v=='+' || v=='-' || v=='*' || v=='/' ||
            ('0'<=v && v<='9')||
            ('A'<=v && v<='Z')||('a'<=v && v<='z')){
            return false; // safe conversion. no need to check at runtime.
        }else{
            semanticErr(Error::Number_Can_Not_Convert_To_Char_Const, *node);
        }
        return true; // need check at runtime;
    }

protected:
    virtual bool functionInvoke(ASTNode &node, bool postVisit) override {
        if(postVisit){
            SymbolTable::FunRecord* f = symTable.getFun(node.id);
            if(node.id >= 2){
                int argIndex = 0;
                vector<ASTNode *>::const_iterator iter;
                for (iter = node.children.begin();
                     iter != node.children.end(); iter++) {
                    SymbolTable::Type t = typeEval(**iter);
                    if(argIndex < f->argCount){
                        if(t==f->argSignature[argIndex]) {
                            //ok
                        }else {
                            if(t==SymbolTable::INT){
                                if(needRuntimeCheck(*iter)){
                                    (*iter)->needCheck = true;
                                }
                            }
//                            semanticErr(Error::Function_Invoke_Type_Mismatch, **iter);
                        }
                        argIndex++;
                    }else{
                        semanticErr(Error::Function_Invoke_Arg_Too_Much, **iter);
                    }
                }
                if(argIndex < f->argCount){
                    semanticErr(Error::Function_Invoke_Arg_Too_Few, node);
                }
            }else{
                //system function.
                if(f->name=="scanf"){
                    for(int i=0; i<node.children.size(); i++){
                        ASTNode* child = node.children[i];
                        SymbolTable::VarRecord* v = symTable.getVar(child->id);
                        if(v->isConst){
                            semanticErr(Error::Const_Can_Not_Be_Assigned, *child);
                        }
                        if(v->isArray){
                            Error::nextErrorDetail << (*child).content;
                            semanticErr(Error::Array_Name_Can_Not_In_Expr, *child);
                        }
                    }
                }
            }
            // check if can be void.
            switch(node.parent->astType){
                case SourceCode:
                case Number:
                case Identifier:
                case CharConst:
                case StringConst:
                case ConstDeclar:
                case VariableDeclar:
                case ArrayDeclar:
                case Param:
                    semanticErr(Error::Should_Not_Happen, node);
                    break;
                case FunctionDeclar: // can be void.
                case SwitchStmt:
                case CaseSegment:
                case DefaultSegment:
                case IfStmt:
                case WhileStmt:
                    // do nothing.
                    break;
                case ConditionExpr:
                case AssignStmt:
                case ReturnStmt:
                case RelationExpr:
                case AddSubExpr:
                case MultiDivExpr:
                case ArraySubscript:
                case FunctionInvoke:
                    // in expression.
                    if(f->returnType==SymbolTable::VOID){
                        semanticErr(Error::Call_Void_Function_In_Expr, node);
                    }
                    break;
            }
        }
        return true;
    }

    virtual bool assignStmt(ASTNode &node, bool postVisit) override {
        if(postVisit){
            ASTNode* leftValue = node.children[0];
            ASTNode* rightValue = node.children[1];
            SymbolTable::VarRecord* v = symTable.getVar(leftValue->id);
            if(leftValue->astType==Identifier){
                if(v->isConst){
                    semanticErr(Error::Const_Can_Not_Be_Assigned, *leftValue);
                }
                if(v->isArray){
                    semanticErr(Error::Array_Name_Can_Not_Be_Assigned, *leftValue);
                }
            }else if(leftValue->astType==ArraySubscript){

            }else{
                semanticErr(Error::Should_Not_Happen, node);
            }

            if(typeEval(*leftValue) != typeEval(*rightValue) && typeEval(*leftValue)==SymbolTable::CHAR){
                if(needRuntimeCheck(rightValue)){
                    rightValue->needCheck = true;
                }
//                semanticErr(Error::Assignment_Type_Mismatch, node);
            }
        }
        return true;
    }

    virtual bool funDeclar(ASTNode& node, bool postVisit) override {
        SymbolTable::FunRecord* f = symTable.getFun(node.id);
        this->funReturnType = f->returnType;
        if(f->returnType!=SymbolTable::VOID){
            if(!postVisit){
                this->funHasReturnStmt = false;
            }else{
                if(!this->funHasReturnStmt){
                    semanticErr(Error::Function_Need_Return, node);
                }
            }
        }
        return true;
    }

    virtual bool returnStmt(ASTNode &node, bool postVisit) override {
        if(postVisit){
            this->funHasReturnStmt = true;
            if(this->funReturnType==SymbolTable::VOID){
                if( ! node.isLeaf() ){
                    semanticErr(Error::Function_Return_Type_Mismatch, node);
                }else{
                    // ok, do nothing.
                }
            }else{
                if( node.isLeaf() ){
                    semanticErr(Error::Function_Return_Type_Mismatch, node);
                }else{
                    if(typeEval(*(node.children[0]))!=this->funReturnType){
                        if(this->funReturnType==SymbolTable::CHAR ){
                            if(needRuntimeCheck(node.children[0])){
                                node.children[0]->needCheck = true;
                            }
                        }else{
                            // no need type convert.
                        }
//                        semanticErr(Error::Function_Return_Type_Mismatch, node);
                    }else{
                        // OK, do nothing.
                    }
                }
            }
        }
        return true;
    }

    // result is same if we do not check. because a
    virtual bool caseStmt(ASTNode &node, bool postVisit) override {
        if(!postVisit){ // check when condition is char but case is int
            ASTNode* condition = node.parent->children[0]->children[0];
            if(typeEval(*condition)==SymbolTable::CHAR){
                needRuntimeCheck(&node);
            }else{
                //ok
            }
        }
        return true;
    }

};




class Semantic{

public:
    Lexer &lexer;
    ASTNode &ast;
    SymbolTable* symTable;
    Semantic(Lexer &lexer, ASTNode &ast) :
            lexer(lexer),
            ast(ast)
    {
        symTable = new SymbolTable();
    }

    SymbolTable* buildSymbolTable() {
        symTable->enterFunction("printf",Void,{});
        symTable->enterFunction("scanf",Void,{});
        BuildSymbolTable visitor(lexer,*symTable);
        visitor.startTraverse(ast);
        return this->symTable;
    }

    void check() {
        TypeChecker typeChecker(lexer, *symTable);
        typeChecker.startTraverse(ast);
    }

};
#endif //C0COMPILER_SEMANTIC_H
